import matplotlib.pyplot as plt

import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
from autograd import grad
from autograd.misc import flatten
from autograd.misc.optimizers import adam


def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
    """Build a list of (weights, biases) tuples, one for each layer."""
    return [
        (
            rs.randn(insize, outsize) * scale,  # weight matrix
            rs.randn(outsize) * scale,  # bias vector
        )
        for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])
    ]


def nn_predict(params, inputs, nonlinearity=np.tanh):
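    """Map inputs through the network; returns the final layer's linear output."""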
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = nonlinearity(outputs)
    return outputs  # the last nonlinearity is discarded, so the final layer is linear


def log_gaussian(params, scale):
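    """Log-density of an isotropic zero-mean Gaussian over all (flattened) parameters."""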
    flat_params, _ = flatten(params)
    return np.sum(norm.logpdf(flat_params, 0, scale))


def logprob(weights, inputs, targets, noise_scale=0.1):
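    """Gaussian log-likelihood of the targets given the network's predictions."""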
    predictions = nn_predict(weights, inputs)
    return np.sum(norm.logpdf(predictions, targets, noise_scale))


def build_toy_dataset(n_data=80, noise_std=0.1):
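    """Noisy samples of cos(x) on two disjoint intervals, with inputs and targets rescaled."""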
    rs = npr.RandomState(0)
    inputs = np.concatenate([np.linspace(0, 3, num=n_data // 2), np.linspace(6, 8, num=n_data // 2)])
    targets = np.cos(inputs) + rs.randn(n_data) * noise_std
    inputs = (inputs - 4.0) / 2.0
    inputs = inputs[:, np.newaxis]
    targets = targets[:, np.newaxis] / 2.0
    return inputs, targets


if __name__ == "__main__":
    init_scale = 0.1
    weight_prior_std = 10.0  # standard deviation of the Gaussian prior over weights
    # 1 input unit, two hidden layers of 4 tanh units, 1 output unit.
    init_params = init_random_params(init_scale, layer_sizes=[1, 4, 4, 1])

    inputs, targets = build_toy_dataset()

    def objective(weights, t):
        # Negative unnormalized log posterior: -log p(targets | weights) - log p(weights).
        return -logprob(weights, inputs, targets) - log_gaussian(weights, weight_prior_std)

    # Sanity check: the gradient of the objective at the initial parameters.
    print(grad(objective)(init_params, 0))

    # Set up figure.
    fig = plt.figure(figsize=(12, 8), facecolor="white")
    ax = fig.add_subplot(111, frameon=False)
    plt.show(block=False)

    def callback(params, t, g):
        print(f"Iteration {t} log likelihood {-objective(params, t)}")

        # Plot data and functions.
        plt.cla()
        ax.plot(inputs.ravel(), targets.ravel(), "bx", ms=12)
        plot_inputs = np.reshape(np.linspace(-7, 7, num=300), (300, 1))
        outputs = nn_predict(params, plot_inputs)
        ax.plot(plot_inputs, outputs, "r", lw=3)
        ax.set_ylim([-1, 1])
        plt.draw()
        plt.pause(1.0 / 60.0)

    print("Optimizing network parameters...")
    optimized_params = adam(grad(objective), init_params, step_size=0.01, num_iters=1000, callback=callback)
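    # A minimal sketch to keep the final figure open after training; assumes an
    # interactive matplotlib backend (safe to drop in headless runs).
    plt.show(block=True)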
